{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.9775677 ,  0.2106216 ,  0.19077969], dtype=float32)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        done = terminated or truncated\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            done = True\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjiElEQVR4nO3df3DU9YH/8dduNrvkB7shgewSSQpXGDHDj6uAYftj7Bw5Ypu2cuKNdTiPeoyOXHBAHOfkTnHa60wY/X77w57i3dyc+keVG26KrZy0pkFDq+FXJBpR44+iScVN+GF2QyCbZPd9f1C2LgbNj0+y74TnY2Zn5PN57zvv/UjyZHc/+azLGGMEAICF3JleAAAAl0KkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWylikHnnkEc2ePVtTpkxRRUWFDh48mKmlAAAslZFI/fd//7c2b96sBx54QK+88ooWL16sqqoqdXZ2ZmI5AABLuTJxgdmKigotW7ZM//Zv/yZJSiaTKi0t1Z133ql77713vJcDALCUZ7y/YF9fn5qamrRly5bUNrfbrcrKSjU2Ng56n3g8rng8nvpzMpnU6dOnVVRUJJfLNeZrBgA4yxij7u5ulZSUyO2+9It64x6pkydPKpFIKBgMpm0PBoN66623Br1PbW2tvv/974/H8gAA46i9vV2zZs265P5xj9RIbNmyRZs3b079ORqNqqysTO3t7fL7/RlcGQBgJGKxmEpLSzV16tTPHDfukZo+fbqysrLU0dGRtr2jo0OhUGjQ+/h8Pvl8vk9t9/v9RAoAJrDPe8tm3M/u83q9WrJkierr61Pbksmk6uvrFQ6Hx3s5AACLZeTlvs2bN2vt2rVaunSprrnmGv3kJz9RT0+Pbr311kwsBwBgqYxE6qabbtKJEye0detWRSIR/eVf/qV+/etff+pkCgDA5S0jvyc1WrFYTIFAQNFolPekAGACGurPca7dBwCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaw47Uvn379O1vf1slJSVyuVx65pln0vYbY7R161bNnDlTOTk5qqys1DvvvJM25vTp01qzZo38fr8KCgq0bt06nTlzZlQPBAAw+Qw7Uj09PVq8eLEeeeSRQfc/+OCDevjhh/XYY4/pwIEDysvLU1VVlXp7e1Nj1qxZo6NHj6qurk67d+/Wvn37dPvtt4/8UQAAJiczCpLMrl27Un9OJpMmFAqZhx56KLWtq6vL+Hw+8/TTTxtjjHnjjTeMJHPo0KHUmD179hiXy2U+/PDDIX3daDRqJJloNDqa5QMAMmSoP8cdfU/q2LFjikQiqqysTG0LBAKqqKhQY2OjJKmxsVEFBQVaunRpakxlZaXcbrcOHDgw6LzxeFyxWCztBgCY/ByNVCQSkSQFg8G07cFgMLUvEomouLg4bb/H41FhYWFqzMVqa2sVCARSt9LSUieXDQCw1IQ4u2/Lli2KRqOpW3t7e6aXBAAYB45GKhQKSZI6OjrStnd0dKT2hUIhdXZ2pu0fGBjQ6dOnU2Mu5vP55Pf7024AgMnP0UjNmTNHoVBI9fX1qW2xWEwHDhxQ
OByWJIXDYXV1dampqSk1Zu/evUomk6qoqHByOQCACc4z3DucOXNG7777burPx44dU3NzswoLC1VWVqZNmzbphz/8oebNm6c5c+bo/vvvV0lJiVatWiVJuuqqq3Tdddfptttu02OPPab+/n5t2LBB3/3ud1VSUuLYAwMATALDPW3whRdeMJI+dVu7dq0x5vxp6Pfff78JBoPG5/OZFStWmNbW1rQ5Tp06ZW6++WaTn59v/H6/ufXWW013d7fjpy4CAOw01J/jLmOMyWAjRyQWiykQCCgajfL+FABMQEP9OT4hzu4DAFyeiBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYaVqRqa2u1bNkyTZ06VcXFxVq1apVaW1vTxvT29qqmpkZFRUXKz8/X6tWr1dHRkTamra1N1dXVys3NVXFxse655x4NDAyM/tEAACaVYUWqoaFBNTU12r9/v+rq6tTf36+VK1eqp6cnNeauu+7Ss88+q507d6qhoUHHjx/XDTfckNqfSCRUXV2tvr4+vfzyy3ryySf1xBNPaOvWrc49KgDA5GBGobOz00gyDQ0Nxhhjurq6THZ2ttm5c2dqzJtvvmkkmcbGRmOMMc8995xxu90mEomkxmzfvt34/X4Tj8eH9HWj0aiRZKLR6GiWDwDIkKH+HB/Ve1LRaFSSVFhYKElqampSf3+/KisrU2Pmz5+vsrIyNTY2SpIaGxu1cOFCBYPB1JiqqirFYjEdPXp00K8Tj8cVi8XSbgCAyW/EkUomk9q0aZO+8pWvaMGCBZKkSCQir9ergoKCtLHBYFCRSCQ15pOBurD/wr7B1NbWKhAIpG6lpaUjXTYAYAIZcaRqamr0+uuva8eOHU6uZ1BbtmxRNBpN3drb28f8awIAMs8zkjtt2LBBu3fv1r59+zRr1qzU9lAopL6+PnV1daU9m+ro6FAoFEqNOXjwYNp8F87+uzDmYj6fTz6fbyRLBQBMYMN6JmWM0YYNG7Rr1y7t3btXc+bMSdu/ZMkSZWdnq76+PrWttbVVbW1tCofDkqRwOKyWlhZ1dnamxtTV1cnv96u8vHw0jwUAMMkM65lUTU2NnnrqKf3yl7/U1KlTU+8hBQIB5eTkKBAIaN26ddq8ebMKCwvl9/t15513KhwOa/ny5ZKklStXqry8XLfccosefPBBRSIR3XfffaqpqeHZEgAgjcsYY4Y82OUadPvjjz+u733ve5LO/zLv3XffraefflrxeFxVVVV69NFH017K++CDD7R+/Xq9+OKLysvL09q1a7Vt2zZ5PENrZiwWUyAQUDQald/vH+ryAQCWGOrP8WFFyhZECgAmtqH+HOfafQAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFqeTC8AwJ8ZYy65z+VyjeNKADsQKcACJpHQQHe3Yq+8oq5Dh9Tb3q7EuXPy+P3KmztX0776VeV+8YvKyssjVrisECkgw5LxuLr271fHs8/q7DvvSJ94NtV/4oTOvfeeTr3wggJXX63iVauUf9VVhAqXDSIFZJAxRieef16RnTs1
0NV16XF9ferav1+9H32ksttvV/6CBYQKlwVOnAAyxCQSOvXb3+r4z3/+mYH6pN4PPlDbf/yHzrz55me+fwVMFkQKyJCet99WZOdOJc+eTdv+YU+Pdre36+k//EG/PX5cPf39aft7P/hAHz39tBJnzozncoGM4OU+IAOS/f2KHj6seCSS2maM0bEzZ/TAkSN6/8wZ9SYS8mdna8G0afp/y5Yp2/3nf1N2v/qqzr73nqYuXszLfpjUeCYFjDNjjGLNzer4xS/Stv/hzBnd9tJLejMa1blEQkZStL9fL3V2auOBAzrV25s2vu2xx8Zx1UBmEClgnCV6evTBz34mk0ikbf/J0aOKXvTS3gUHT55U3fHjadvMJcYCkwmRAsbZuWPHCAwwREQKGGen9u5Voqcn08sAJgQiBYwjk0jIDAwMuq+6tFTZlzgJYnZ+vhYVFo7l0gArcXYfMI76P/5Y/R9/POi+qpISSdIPX31VfYmEkpKyXC4VeL36/8uW6Qv5+eO4UsAORAoYR90tLep+7bVB97lcLlWVlGhWbq52//GPOtXbq9n5+bppzhwV+XyfGu/7U9SAyYxIAePpc64S4XK5tGDaNC2YNu1zpyq+/nqnVgVYi/ekgHFiEolLvtQ3ElOuuIJf5MWkR6SAcZLo7VX0yBFH5souKpJ7kJcAgcmGSAHjJNnbqzOXeD9quPxLligrN9eRuQCbESlgnCT7+hybK6esTG6v17H5AFsRKWCc9LzzjmNzZU2ZIpebb19MfvwtB8bJ6b17HZnHGwrJd8UVjswF2I5IAePAJJOKX3SB2JGaUlKinLIyR+YCbEekgHHQd/Kkkpe4HNJwZeXlyTN1qiNzAbYjUsA4+HjfPg1Eo85Mxu9G4TJCpIBxEO/sdOTjOdy5uSr+znccWBEwMRApYIwl+/sdO/3c7fHwfhQuK0QKGGP9XV3qP3nSmcmysrjSBC4rRAoYY2fffVfdLS2OzFVcXe3IPMBEQaSAMWSMOf9e1Odc/Xyo8ubPd2QeYKIgUsBYSiYVj0Qcmy67oMCxuYCJgEgBYyjZ33/JDzkcruxp0+T2evl4DlxWiBQwhkx/v2PvRxV8+cvy8EwKlxkiBYyh/q4ux+byBYOc2YfLDpECxtDHL73k2EkTrqwsXurDZYdIAWOo66WXHJnHW1ysnNmzHZkLmEiIFDBGTCIhZ55DSb6ZM5U7b55DswETB5ECxkjv8eNKnjvnyFxur5f3o3BZIlLAGIkdOeLIiROurCzlL1jA+1G4LBEpYAwYY9R34oSMAxeWdWVnK7BsmQOrAiYeIgWMgWRvrxLd3c5M5nZrysyZzswFTDBEChgDvW1tijU3OzIXL/LhckakgDEw0NOj/tOnHZkreOONkptvVVye+JsPOMwY49hZfZKUd+WVjs0FTDRECnBaMqmz77/v2HRZOTmc2YfL1rAitX37di1atEh+v19+v1/hcFh79uxJ7e/t7VVNTY2KioqUn5+v1atXq6OjI22OtrY2VVdXKzc3V8XFxbrnnns0MDDgzKMBLJDs61PnM884Mpf/6qvlLSpyZC5gIhpWpGbNmqVt27apqalJhw8f1l/91V/p+uuv19GjRyVJd911l5599lnt3LlTDQ0NOn78uG644YbU/ROJhKqrq9XX16eXX35ZTz75pJ544glt3brV2UcFZJBJJmUc+ofXlFmzlJWX58hcwETkMmZ0V78sLCzUQw89pBtvvFEzZszQU089pRtvvFGS9NZbb+mqq65SY2Ojli9frj179uhb3/qWjh8/rmAwKEl67LHH9E//9E86ceKEvF7vkL5mLBZTIBBQNBqV3+8fzfIBx/W8957euvtuKZkc9Vwlf/d3Cv3t3/JyHyadof4cH/F7UolEQjt27FBPT4/C4bCamprU39+vysrK1Jj58+errKxMjY2NkqTGxkYtXLgwFShJqqqqUiwWSz0bG0w8HlcsFku7AbaKNTU5cuVzl88nj99PoHBZG3akWlpa
lJ+fL5/PpzvuuEO7du1SeXm5IpGIvF6vCi76ULZgMKjInz4+OxKJpAXqwv4L+y6ltrZWgUAgdSstLR3usoFxEz10yJFITZk5U1MXL3ZgRcDENexIXXnllWpubtaBAwe0fv16rV27Vm+88cZYrC1ly5YtikajqVt7e/uYfj1gpBLnzjn2flRWfr58M2Y4MhcwUXmGewev16u5c+dKkpYsWaJDhw7ppz/9qW666Sb19fWpq6sr7dlUR0eHQqGQJCkUCungwYNp8104++/CmMH4fD75uAI0JoCe1lb1f/yxI3O5XC65PMP+FgUmlVH/nlQymVQ8HteSJUuUnZ2t+vr61L7W1la1tbUpHA5LksLhsFpaWtTZ2ZkaU1dXJ7/fr/Ly8tEuBci4M0ePOnOlCbdbU6++evTzABPcsP6ZtmXLFn3jG99QWVmZuru79dRTT+nFF1/Ub37zGwUCAa1bt06bN29WYWGh/H6/7rzzToXDYS1fvlyStHLlSpWXl+uWW27Rgw8+qEgkovvuu081NTU8U8KEZ4zRKE+WTXFlZanwa19zZC5gIhtWpDo7O/X3f//3+uijjxQIBLRo0SL95je/0V//9V9Lkn784x/L7XZr9erVisfjqqqq0qOPPpq6f1ZWlnbv3q3169crHA4rLy9Pa9eu1Q9+8ANnHxWQAYmeHvWdOOHMZC6Xsi86CQm4HI3696Qygd+Tgo3OHjumd//1X9V/8uSo55r6pS9p3v33854UJq0x/z0pAH9mjFGip8eRQElSQUUFVz4HRKQAxzh1Vp8k5ZSWSvwSL0CkAEcYo7PvvefcfG43V5oARKQAR5hEQieff96RuQJLl/Jx8cCfECnAAcl4XMm+Pkfm8l1xhbI4IQiQRKQAR5xra3PkqueS5J4yRW7O6gMkESnAEZGdOx25Zp8rO1ue/HwHVgRMDkQKcIDp73dkHt/MmZrGlSaAFCIFjNJALKZkb68jc7l9Pq40AXwCkQJG6eyxY+pz4qKykrxFRXLxS7xACt8NwCidfecdZ6404XZrelXV6OcBJhEiBYyCSSSUjMedmczlUu6fPqsNwHlEChiFgVhMZ1pbHZvP7fU6NhcwGRApYBT6T59Wd3OzI3NNXbBArqwsR+YCJgsiBYyQMUbGoV/glaRpX/0qH80BXIRIAaPQ+8c/OjaXd/p0rnwOXIRIASNljE7W1TkylTcUkqeggCufAxchUsBIGaNzf/iDI1Plzpkj74wZjswFTCZEChihge5uGWMcmcs7fTrX7AMGQaSAETp77JhMIuHMZFlZXGkCGATfFcAInfrtb2Uc+Awpz7RpKqiocGBFwORDpIARSA4MKHH2rCNzZeXmKmf2bEfmAiYbIgWMwEBXlxI9PY7M5fJ45MnLc2QuYLIhUsAIRA8fVs/bbzsyV3Yg4Mg8wGREpIARMAMDjn1c/MzvfteReYDJiEgBw5QcGNBALObMZC6XfCUlzswFTEJEChimRE+PY1c+9/j9XPkc+AxEChimRHe3uo8ccWSuaddeK7fP58hcwGREpIBhMMYo4dSHHOr85ZC48jlwaUQKGKaz777r2FxZOTmOzQVMRkQKGKaPX3rJkXm8waCyi4q48jnwGYgUMBzG6NyxY45MlfvFL2rKFVc4MhcwWREpYBj6Tpxw7KKyHr9fWVxpAvhMRAoYhtO/+52Svb2OzcdLfcBnI1LAMJx7//3zV5sYJY/fr6Kvf330CwImOSIFDEP+lVfKlZ096nncOTnKnTvXgRUBkxuRAoZhxre+pazc3FHP43K5uNIEMARECsiA6dddl+klABMCkQIyYOrChZleAjAhEClguBw4Iy+7sNCBhQCTH5EChmnW9743qvsHV62SZ+pUZxYDTHJEChgGl8ulqYsXK2fOnBHd3ztjhgLXXMNJE8AQESlgmLKnTdPMm29WdlHRsO7n8vlUfP31yrvyyjFaGTD5EClgmFxutwqWLdOsW28d8ntL7txchVav1oxvflNuB37P
Crhc8EE2wAi4srI07Wtfk8vjUeevfqWet98e/EoULpdyvvAFTa+q0vTKSrn57ChgWPiOAUbI5XKpoKJCuX/xF+puadHHL7+sc++/r4HubmXl5GjKrFnyL1migmuukW/mTJ5BASNApIBRcGVlyRcKyVtcrMKvf10y5vxNktxuubKyJJeLC8kCI0SkAAe43G653LzFCziN7yoAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1hpVpLZt2yaXy6VNmzaltvX29qqmpkZFRUXKz8/X6tWr1dHRkXa/trY2VVdXKzc3V8XFxbrnnns0MNhHbwMALmsjjtShQ4f07//+71q0aFHa9rvuukvPPvusdu7cqYaGBh0/flw33HBDan8ikVB1dbX6+vr08ssv68knn9QTTzyhrVu3jvxRAAAmJzMC3d3dZt68eaaurs5ce+21ZuPGjcYYY7q6ukx2drbZuXNnauybb75pJJnGxkZjjDHPPfeccbvdJhKJpMZs377d+P1+E4/Hh/T1o9GokWSi0ehIlg8AyLCh/hwf0TOpmpoaVVdXq7KyMm17U1OT+vv707bPnz9fZWVlamxslCQ1NjZq4cKFCgaDqTFVVVWKxWI6evTooF8vHo8rFoul3QAAk59nuHfYsWOHXnnlFR06dOhT+yKRiLxerwoKCtK2B4NBRSKR1JhPBurC/gv7BlNbW6vvf//7w10qAGCCG9Yzqfb2dm3cuFE///nPNWXKlLFa06ds2bJF0Wg0dWtvbx+3rw0AyJxhRaqpqUmdnZ26+uqr5fF45PF41NDQoIcfflgej0fBYFB9fX3q6upKu19HR4dCoZAkKRQKfepsvwt/vjDmYj6fT36/P+0GAJj8hhWpFStWqKWlRc3Nzanb0qVLtWbNmtR/Z2dnq76+PnWf1tZWtbW1KRwOS5LC4bBaWlrU2dmZGlNXVye/36/y8nKHHhYAYDIY1ntSU6dO1YIFC9K25eXlqaioKLV93bp12rx5swoLC+X3+3XnnXcqHA5r+fLlkqSVK1eqvLxct9xyix588EFFIhHdd999qqmpkc/nc+hhAQAmg2GfOPF5fvzjH8vtdmv16tWKx+OqqqrSo48+mtqflZWl3bt3a/369QqHw8rLy9PatWv1gx/8wOmlAAAmOJcxxmR6EcMVi8UUCAQUjUZ5fwoAJqCh/hzn2n0AAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGt5Mr2AkTDGSJJisViGVwIAGIkLP78v/Dy/lAkZqVOnTkmSSktLM7wSAMBodHd3KxAIXHL/hIxUYWGhJKmtre0zH9zlLhaLqbS0VO3t7fL7/ZlejrU4TkPDcRoajtPQGGPU3d2tkpKSzxw3ISPldp9/Ky0QCPCXYAj8fj/HaQg4TkPDcRoajtPnG8qTDE6cAABYi0gBAKw1ISPl8/n0wAMPyOfzZXopVuM4DQ3HaWg4TkPDcXKWy3ze+X8AAGTIhHwmBQC4PBApAIC1iBQAwFpECgBgrQkZqUceeUSzZ8/WlClTVFFRoYMHD2Z6SeNq3759+va3v62S
khK5XC4988wzafuNMdq6datmzpypnJwcVVZW6p133kkbc/r0aa1Zs0Z+v18FBQVat26dzpw5M46PYmzV1tZq2bJlmjp1qoqLi7Vq1Sq1tramjent7VVNTY2KioqUn5+v1atXq6OjI21MW1ubqqurlZubq+LiYt1zzz0aGBgYz4cyprZv365FixalfvE0HA5rz549qf0co8Ft27ZNLpdLmzZtSm3jWI0RM8Hs2LHDeL1e81//9V/m6NGj5rbbbjMFBQWmo6Mj00sbN88995z5l3/5F/OLX/zCSDK7du1K279t2zYTCATMM888Y1599VXzne98x8yZM8ecO3cuNea6664zixcvNvv37ze/+93vzNy5c83NN988zo9k7FRVVZnHH3/cvP7666a5udl885vfNGVlZebMmTOpMXfccYcpLS019fX15vDhw2b58uXmy1/+cmr/wMCAWbBggamsrDRHjhwxzz33nJk+fbrZsmVLJh7SmPjVr35l/vd//9e8/fbbprW11fzzP/+zyc7ONq+//roxhmM0mIMHD5rZs2ebRYsWmY0bN6a2c6zGxoSL1DXXXGNqampSf04kEqakpMTU1tZmcFWZc3GkksmkCYVC5qGHHkpt6+rqMj6fzzz99NPGGGPeeOMNI8kcOnQoNWbPnj3G5XKZDz/8cNzWPp46OzuNJNPQ0GCMOX9MsrOzzc6dO1Nj3nzzTSPJNDY2GmPO/2PA7XabSCSSGrN9+3bj9/tNPB4f3wcwjqZNm2b+8z//k2M0iO7ubjNv3jxTV1dnrr322lSkOFZjZ0K93NfX16empiZVVlamtrndblVWVqqxsTGDK7PHsWPHFIlE0o5RIBBQRUVF6hg1NjaqoKBAS5cuTY2prKyU2+3WgQMHxn3N4yEajUr688WJm5qa1N/fn3ac5s+fr7KysrTjtHDhQgWDwdSYqqoqxWIxHT16dBxXPz4SiYR27Nihnp4ehcNhjtEgampqVF1dnXZMJP4+jaUJdYHZkydPKpFIpP1PlqRgMKi33norQ6uySyQSkaRBj9GFfZFIRMXFxWn7PR6PCgsLU2Mmk2QyqU2bNukrX/mKFixYIOn8MfB6vSooKEgbe/FxGuw4Xtg3WbS0tCgcDqu3t1f5+fnatWuXysvL1dzczDH6hB07duiVV17RoUOHPrWPv09jZ0JFChiJmpoavf766/r973+f6aVY6corr1Rzc7Oi0aj+53/+R2vXrlVDQ0Oml2WV9vZ2bdy4UXV1dZoyZUqml3NZmVAv902fPl1ZWVmfOmOmo6NDoVAoQ6uyy4Xj8FnHKBQKqbOzM23/wMCATp8+PemO44YNG7R792698MILmjVrVmp7KBRSX1+furq60sZffJwGO44X9k0WXq9Xc+fO1ZIlS1RbW6vFixfrpz/9KcfoE5qamtTZ2amrr75aHo9HHo9HDQ0Nevjhh+XxeBQMBjlWY2RCRcrr9WrJkiWqr69PbUsmk6qvr1c4HM7gyuwxZ84chUKhtGMUi8V04MCB1DEKh8Pq6upSU1NTaszevXuVTCZVUVEx7mseC8YYbdiwQbt27dLevXs1Z86ctP1LlixRdnZ22nFqbW1VW1tb2nFqaWlJC3pdXZ38fr/Ky8vH54FkQDKZVDwe5xh9wooVK9TS0qLm5ubUbenSpVqzZk3qvzlWYyTTZ24M144dO4zP5zNPPPGEeeONN8ztt99uCgoK0s6Ymey6u7vNkSNHzJEjR4wk86Mf/cgcOXLEfPDBB8aY86egFxQUmF/+8pfmtddeM9dff/2gp6B/6UtfMgcOHDC///3vzbx58ybVKejr1683gUDAvPjii+ajjz5K3c6ePZsac8cdd5iysjKzd+9ec/jwYRMOh004HE7tv3DK8MqVK01zc7P59a9/bWbMmDGpThm+9957TUNDgzl27Jh57bXXzL333mtcLpd5/vnnjTEco8/yybP7jOFYjZUJFyljjPnZz35mysrKjNfrNddcc43Zv39/
ppc0rl544QUj6VO3tWvXGmPOn4Z+//33m2AwaHw+n1mxYoVpbW1Nm+PUqVPm5ptvNvn5+cbv95tbb73VdHd3Z+DRjI3Bjo8k8/jjj6fGnDt3zvzjP/6jmTZtmsnNzTV/8zd/Yz766KO0ed5//33zjW98w+Tk5Jjp06ebu+++2/T394/zoxk7//AP/2C+8IUvGK/Xa2bMmGFWrFiRCpQxHKPPcnGkOFZjg4/qAABYa0K9JwUAuLwQKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYK3/AzhV1hssQ/qXAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#打印游戏\n",
    "def show():\n",
    "    plt.imshow(env.render())\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=16, bias=True)\n",
       " ),\n",
       " Sequential(\n",
       "   (0): Linear(in_features=3, out_features=128, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=128, out_features=16, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## 搭建环境\n",
    "import torch\n",
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3,128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128,16),\n",
    ")\n",
    "#经验网络,用于评估一个状态的分数\n",
    "next_model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 16),\n",
    ")\n",
    "#把model的参数复制给next_model\n",
    "next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "model, next_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5, -0.6666666666666667)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "def get_action(state):\n",
    "    state = torch.FloatTensor(state).reshape(1,3)\n",
    "    action = model(state).argmax().item()\n",
    "\n",
    "    if random.random() < 0.01:\n",
    "        action = random.choice(range(16))\n",
    "    #离散动作连续化\n",
    "    action_continuous = action\n",
    "    action_continuous /=15\n",
    "    action_continuous *= 4\n",
    "    action_continuous -=2\n",
    "    return action,action_continuous\n",
    "get_action([0,0,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((200, 0), 200)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#样本池\n",
    "datas = []\n",
    "\n",
    "\n",
    "#向样本池中添加N条数据,删除M条最古老的数据\n",
    "def update_data():\n",
    "    old_count = len(datas)\n",
    "\n",
    "    #玩到新增了N个数据为止\n",
    "    while len(datas) - old_count < 200:\n",
    "        #初始化游戏\n",
    "        state = env.reset()\n",
    "\n",
    "        #玩到游戏结束为止\n",
    "        over = False\n",
    "        while not over:\n",
    "            #根据当前状态得到一个动作\n",
    "            action, action_continuous = get_action(state)\n",
    "\n",
    "            #执行动作,得到反馈\n",
    "            next_state, reward, over, _ = env.step([action_continuous])\n",
    "\n",
    "            #记录数据样本\n",
    "            datas.append((state, action, reward, next_state, over))\n",
    "\n",
    "            #更新游戏状态,开始下一个动作\n",
    "            state = next_state\n",
    "\n",
    "    update_count = len(datas) - old_count\n",
    "    drop_count = max(len(datas) - 5000, 0)\n",
    "\n",
    "    #数据上限,超出时从最古老的开始删除\n",
    "    while len(datas) > 5000:\n",
    "        datas.pop(0)\n",
    "\n",
    "    return update_count, drop_count\n",
    "\n",
    "\n",
    "update_data(), len(datas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\cgq10\\AppData\\Local\\Temp\\ipykernel_23440\\3735624114.py:7: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\torch\\csrc\\utils\\tensor_new.cpp:248.)\n",
      "  state = torch.FloatTensor([i[0] for i in samples]).reshape(-1,3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[ 0.9661, -0.2583, -0.8100],\n",
       "         [-0.7755, -0.6313, -8.0000],\n",
       "         [-0.0614,  0.9981, -6.9065],\n",
       "         [ 0.9832, -0.1824, -0.3489],\n",
       "         [-0.2182, -0.9759,  4.6394],\n",
       "         [ 0.9971,  0.0755, -4.5041],\n",
       "         [-0.9932,  0.1165, -8.0000],\n",
       "         [ 0.0303, -0.9995, -7.2322],\n",
       "         [ 0.9366, -0.3504, -1.0300],\n",
       "         [ 0.2293, -0.9734, -0.5761],\n",
       "         [ 0.4272,  0.9041, -5.9357],\n",
       "         [ 0.9414,  0.3373, -1.1833],\n",
       "         [ 0.6533, -0.7571, -5.6913],\n",
       "         [ 0.0976, -0.9952,  2.5157],\n",
       "         [-0.5957, -0.8032, -4.9150],\n",
       "         [ 0.8551,  0.5184,  1.8610],\n",
       "         [-0.0979, -0.9952, -2.9356],\n",
       "         [-0.9907, -0.1362, -6.6126],\n",
       "         [ 0.0375,  0.9993, -6.7429],\n",
       "         [-0.1053,  0.9944, -5.5986],\n",
       "         [ 0.1603, -0.9871, -1.4062],\n",
       "         [ 0.6001,  0.7999,  2.6012],\n",
       "         [ 0.9114,  0.4114, -4.6014],\n",
       "         [-0.8541,  0.5201,  6.8336],\n",
       "         [ 0.5677, -0.8232, -3.0286],\n",
       "         [ 0.6988,  0.7153,  1.9647],\n",
       "         [-0.8705, -0.4922,  6.9408],\n",
       "         [ 0.7233, -0.6906, -5.3946],\n",
       "         [ 0.7936,  0.6084, -4.9177],\n",
       "         [-0.2175,  0.9761, -7.1130],\n",
       "         [-0.9860, -0.1668,  7.1659],\n",
       "         [ 0.7884,  0.6151,  2.3498],\n",
       "         [ 0.9577,  0.2877,  0.3979],\n",
       "         [ 0.7691,  0.6391, -5.0982],\n",
       "         [-0.3956,  0.9184, -7.4553],\n",
       "         [ 0.2210, -0.9753,  0.8315],\n",
       "         [ 0.8628,  0.5056, -4.8254],\n",
       "         [ 0.2508,  0.9680, -6.2979],\n",
       "         [ 0.9757, -0.2190, -0.7458],\n",
       "         [ 0.8322,  0.5544,  0.3248],\n",
       "         [-0.0770,  0.9970,  5.1030],\n",
       "         [ 0.6746,  0.7382, -3.0543],\n",
       "         [ 0.9962,  0.0866, -0.9943],\n",
       "         [ 0.9375,  0.3480, -2.1025],\n",
       "         [-0.6390,  0.7692, -6.5905],\n",
       "         [-0.0401, -0.9992, -5.0824],\n",
       "         [ 0.1927, -0.9813,  1.9117],\n",
       "         [-0.9757,  0.2192, -7.9201],\n",
       "         [ 0.9551,  0.2962, -4.5862],\n",
       "         [ 0.6909,  0.7230,  2.9112],\n",
       "         [ 0.9888,  0.1494, -4.5430],\n",
       "         [-0.2951,  0.9555, -7.3195],\n",
       "         [ 0.9882, -0.1532, -4.5875],\n",
       "         [-0.6959, -0.7181, -8.0000],\n",
       "         [ 0.0609,  0.9981,  4.4865],\n",
       "         [ 0.2050,  0.9788, -4.6801],\n",
       "         [-0.6305,  0.7762,  6.3515],\n",
       "         [ 0.6216,  0.7833, -5.3652],\n",
       "         [ 0.5133,  0.8582, -5.7118],\n",
       "         [ 0.7722, -0.6354, -2.1260],\n",
       "         [-0.4685, -0.8835, -8.0000],\n",
       "         [-0.4997, -0.8662, -4.6808],\n",
       "         [-0.8486, -0.5291, -7.1127],\n",
       "         [-0.5835, -0.8121, -8.0000]]),\n",
       " tensor([[10],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 1],\n",
       "         [ 0],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 9],\n",
       "         [ 5],\n",
       "         [ 4],\n",
       "         [10],\n",
       "         [ 4],\n",
       "         [ 5],\n",
       "         [ 4],\n",
       "         [10],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 5],\n",
       "         [10],\n",
       "         [ 4],\n",
       "         [ 5],\n",
       "         [ 9],\n",
       "         [10],\n",
       "         [ 5],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 5],\n",
       "         [10],\n",
       "         [10],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 5],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [10],\n",
       "         [10],\n",
       "         [ 5],\n",
       "         [ 4],\n",
       "         [10],\n",
       "         [10],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 5],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 5],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 5],\n",
       "         [ 4],\n",
       "         [ 5],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 9],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 4],\n",
       "         [ 4]]),\n",
       " tensor([[ -0.1343],\n",
       "         [-12.4443],\n",
       "         [ -7.4349],\n",
       "         [ -0.0488],\n",
       "         [ -5.3632],\n",
       "         [ -2.0352],\n",
       "         [-15.5506],\n",
       "         [ -7.6044],\n",
       "         [ -0.2344],\n",
       "         [ -1.8278],\n",
       "         [ -4.7996],\n",
       "         [ -0.2589],\n",
       "         [ -3.9776],\n",
       "         [ -2.8031],\n",
       "         [ -7.2957],\n",
       "         [ -0.6438],\n",
       "         [ -3.6478],\n",
       "         [-13.4031],\n",
       "         [ -6.8984],\n",
       "         [ -5.9454],\n",
       "         [ -2.1857],\n",
       "         [ -1.5366],\n",
       "         [ -2.2980],\n",
       "         [-11.4026],\n",
       "         [ -1.8526],\n",
       "         [ -1.0218],\n",
       "         [-11.7191],\n",
       "         [ -3.4921],\n",
       "         [ -2.8471],\n",
       "         [ -8.2646],\n",
       "         [-13.9804],\n",
       "         [ -0.9915],\n",
       "         [ -0.1015],\n",
       "         [ -3.0807],\n",
       "         [ -9.4696],\n",
       "         [ -1.8866],\n",
       "         [ -2.6103],\n",
       "         [ -5.7025],\n",
       "         [ -0.1048],\n",
       "         [ -0.3563],\n",
       "         [ -5.3200],\n",
       "         [ -1.6232],\n",
       "         [ -0.1068],\n",
       "         [ -0.5689],\n",
       "         [ -9.4698],\n",
       "         [ -5.1790],\n",
       "         [ -2.2617],\n",
       "         [-14.8035],\n",
       "         [ -2.1946],\n",
       "         [ -1.5009],\n",
       "         [ -2.0872],\n",
       "         [ -8.8568],\n",
       "         [ -2.1290],\n",
       "         [-11.8786],\n",
       "         [ -4.2929],\n",
       "         [ -4.0525],\n",
       "         [ -9.1105],\n",
       "         [ -3.6894],\n",
       "         [ -4.3278],\n",
       "         [ -0.9261],\n",
       "         [-10.6376],\n",
       "         [ -6.5768],\n",
       "         [-11.7375],\n",
       "         [-11.2135]]),\n",
       " tensor([[ 9.5341e-01, -3.0168e-01, -9.0374e-01],\n",
       "         [-9.6015e-01, -2.7948e-01, -8.0000e+00],\n",
       "         [ 2.5078e-01,  9.6805e-01, -6.2979e+00],\n",
       "         [ 9.7573e-01, -2.1898e-01, -7.4578e-01],\n",
       "         [-3.9564e-02, -9.9922e-01,  3.6075e+00],\n",
       "         [ 9.8819e-01, -1.5324e-01, -4.5875e+00],\n",
       "         [-8.6943e-01,  4.9406e-01, -8.0000e+00],\n",
       "         [-3.6130e-01, -9.3245e-01, -8.0000e+00],\n",
       "         [ 9.1326e-01, -4.0739e-01, -1.2328e+00],\n",
       "         [ 1.6032e-01, -9.8707e-01, -1.4062e+00],\n",
       "         [ 6.5282e-01,  7.5751e-01, -5.3976e+00],\n",
       "         [ 9.5457e-01,  2.9797e-01, -8.3028e-01],\n",
       "         [ 3.8200e-01, -9.2416e-01, -6.3991e+00],\n",
       "         [ 1.8025e-01, -9.8362e-01,  1.6692e+00],\n",
       "         [-7.9618e-01, -6.0506e-01, -5.6574e+00],\n",
       "         [ 7.8844e-01,  6.1511e-01,  2.3498e+00],\n",
       "         [-2.8517e-01, -9.5848e-01, -3.8220e+00],\n",
       "         [-9.7884e-01,  2.0461e-01, -6.8548e+00],\n",
       "         [ 3.3747e-01,  9.4134e-01, -6.1335e+00],\n",
       "         [ 1.4362e-01,  9.8963e-01, -4.9928e+00],\n",
       "         [ 4.8672e-02, -9.9881e-01, -2.2465e+00],\n",
       "         [ 4.6056e-01,  8.8763e-01,  3.3011e+00],\n",
       "         [ 9.7959e-01,  2.0098e-01, -4.4329e+00],\n",
       "         [-9.8186e-01,  1.8959e-01,  7.1237e+00],\n",
       "         [ 4.1183e-01, -9.1126e-01, -3.5860e+00],\n",
       "         [ 6.0015e-01,  7.9989e-01,  2.6012e+00],\n",
       "         [-6.6882e-01, -7.4342e-01,  6.4717e+00],\n",
       "         [ 4.8458e-01, -8.7475e-01, -6.0525e+00],\n",
       "         [ 9.1145e-01,  4.1141e-01, -4.6014e+00],\n",
       "         [ 1.0661e-01,  9.9430e-01, -6.5210e+00],\n",
       "         [-8.7049e-01, -4.9218e-01,  6.9408e+00],\n",
       "         [ 6.9089e-01,  7.2296e-01,  2.9112e+00],\n",
       "         [ 9.4683e-01,  3.2173e-01,  7.1373e-01],\n",
       "         [ 8.9811e-01,  4.3977e-01, -4.7589e+00],\n",
       "         [-6.1374e-02,  9.9811e-01, -6.9065e+00],\n",
       "         [ 2.2099e-01, -9.7528e-01,  6.8660e-05],\n",
       "         [ 9.5512e-01,  2.9622e-01, -4.5862e+00],\n",
       "         [ 5.1334e-01,  8.5818e-01, -5.7118e+00],\n",
       "         [ 9.6606e-01, -2.5830e-01, -8.1001e-01],\n",
       "         [ 8.0821e-01,  5.8890e-01,  8.4058e-01],\n",
       "         [-3.5659e-01,  9.3426e-01,  5.7508e+00],\n",
       "         [ 7.6593e-01,  6.4293e-01, -2.6407e+00],\n",
       "         [ 9.9898e-01,  4.5249e-02, -8.2937e-01],\n",
       "         [ 9.6420e-01,  2.6519e-01, -1.7415e+00],\n",
       "         [-3.7598e-01,  9.2663e-01, -6.1536e+00],\n",
       "         [-3.3228e-01, -9.4318e-01, -5.9718e+00],\n",
       "         [ 2.4517e-01, -9.6948e-01,  1.0757e+00],\n",
       "         [-8.1632e-01,  5.7759e-01, -7.8957e+00],\n",
       "         [ 9.9715e-01,  7.5458e-02, -4.5041e+00],\n",
       "         [ 5.6055e-01,  8.2812e-01,  3.3534e+00],\n",
       "         [ 9.9691e-01, -7.8557e-02, -4.5710e+00],\n",
       "         [ 3.7540e-02,  9.9930e-01, -6.7429e+00],\n",
       "         [ 9.2263e-01, -3.8570e-01, -4.8424e+00],\n",
       "         [-9.2062e-01, -3.9047e-01, -8.0000e+00],\n",
       "         [-1.9455e-01,  9.8089e-01,  5.1351e+00],\n",
       "         [ 3.9935e-01,  9.1680e-01, -4.0860e+00],\n",
       "         [-8.5413e-01,  5.2007e-01,  6.8336e+00],\n",
       "         [ 7.9361e-01,  6.0843e-01, -4.9177e+00],\n",
       "         [ 7.1700e-01,  6.9708e-01, -5.2082e+00],\n",
       "         [ 6.8543e-01, -7.2814e-01, -2.5425e+00],\n",
       "         [-7.7552e-01, -6.3132e-01, -8.0000e+00],\n",
       "         [-7.1509e-01, -6.9903e-01, -5.4705e+00],\n",
       "         [-9.8473e-01, -1.7412e-01, -7.6495e+00],\n",
       "         [-8.5366e-01, -5.2083e-01, -8.0000e+00]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0]]))"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#获取一批数据样本\n",
    "def get_sample():\n",
    "    #从样本池中采样\n",
    "    samples = random.sample(datas, 64)\n",
    "\n",
    "    #[b, 3]\n",
    "    state = torch.FloatTensor([i[0] for i in samples]).reshape(-1,3)\n",
    "    #[b, 1]\n",
    "    action = torch.LongTensor([i[1] for i in samples]).reshape(-1,1)\n",
    "    #[b, 1]\n",
    "    reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1,1)\n",
    "    #[b, 3]\n",
    "    next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1,3)\n",
    "    #[b, 1]\n",
    "    over = torch.LongTensor([i[4] for i in samples]).reshape(-1,1)\n",
    "\n",
    "    return state, action, reward, next_state, over\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "state, action, reward, next_state, over"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.3482],\n",
       "        [2.1216],\n",
       "        [1.8461],\n",
       "        [0.0336],\n",
       "        [1.0106],\n",
       "        [0.8619],\n",
       "        [2.2778],\n",
       "        [1.6464],\n",
       "        [0.3536],\n",
       "        [0.3854],\n",
       "        [1.4646],\n",
       "        [0.3995],\n",
       "        [1.1381],\n",
       "        [0.6683],\n",
       "        [1.3003],\n",
       "        [0.5936],\n",
       "        [0.6968],\n",
       "        [1.9040],\n",
       "        [1.7803],\n",
       "        [1.5385],\n",
       "        [0.4234],\n",
       "        [0.6414],\n",
       "        [0.9386],\n",
       "        [1.2212],\n",
       "        [0.6246],\n",
       "        [0.5944],\n",
       "        [1.3460],\n",
       "        [1.0578],\n",
       "        [1.0729],\n",
       "        [1.9386],\n",
       "        [1.3513],\n",
       "        [0.6363],\n",
       "        [0.4319],\n",
       "        [1.1284],\n",
       "        [2.0690],\n",
       "        [0.4771],\n",
       "        [1.0184],\n",
       "        [1.6113],\n",
       "        [0.3482],\n",
       "        [0.4543],\n",
       "        [0.9518],\n",
       "        [0.6348],\n",
       "        [0.3712],\n",
       "        [0.3383],\n",
       "        [1.9124],\n",
       "        [1.1618],\n",
       "        [0.5907],\n",
       "        [2.2666],\n",
       "        [0.9128],\n",
       "        [0.6872],\n",
       "        [0.8802],\n",
       "        [2.0100],\n",
       "        [0.8621],\n",
       "        [2.0878],\n",
       "        [0.8663],\n",
       "        [1.2256],\n",
       "        [1.1280],\n",
       "        [1.2534],\n",
       "        [1.3792],\n",
       "        [0.5059],\n",
       "        [2.0006],\n",
       "        [1.2163],\n",
       "        [1.9388],\n",
       "        [2.0435]], grad_fn=<GatherBackward0>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_value(state, action):\n",
    "    #使用状态计算出动作的logits\n",
    "    #[b, 3] -> [b, 11]\n",
    "    value = model(state)\n",
    "\n",
    "    #根据实际使用的action取出每一个值\n",
    "    #这个值就是模型评估的在该状态下,执行动作的分数\n",
    "    #在执行动作前,显然并不知道会得到的反馈和next_state\n",
    "    #所以这里不能也不需要考虑next_state和reward\n",
    "    #[b, 11] -> [b, 1]\n",
    "    value = value.gather(dim=1, index=action)\n",
    "\n",
    "    return value\n",
    "\n",
    "\n",
    "get_value(state, action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  0.2056],\n",
       "        [-10.2711],\n",
       "        [ -5.8558],\n",
       "        [  0.2924],\n",
       "        [ -4.5447],\n",
       "        [ -1.1903],\n",
       "        [-13.3051],\n",
       "        [ -5.6812],\n",
       "        [  0.1429],\n",
       "        [ -1.4129],\n",
       "        [ -3.5749],\n",
       "        [  0.1193],\n",
       "        [ -2.6486],\n",
       "        [ -2.2494],\n",
       "        [ -5.7656],\n",
       "        [ -0.0203],\n",
       "        [ -2.7104],\n",
       "        [-11.4393],\n",
       "        [ -5.3856],\n",
       "        [ -4.6503],\n",
       "        [ -1.6544],\n",
       "        [ -0.8321],\n",
       "        [ -1.4544],\n",
       "        [-10.1244],\n",
       "        [ -1.1674],\n",
       "        [ -0.3931],\n",
       "        [-10.4399],\n",
       "        [ -2.2642],\n",
       "        [ -1.9273],\n",
       "        [ -6.5915],\n",
       "        [-12.6613],\n",
       "        [ -0.3181],\n",
       "        [  0.3437],\n",
       "        [ -2.1161],\n",
       "        [ -7.6605],\n",
       "        [ -1.4785],\n",
       "        [ -1.7157],\n",
       "        [ -4.3509],\n",
       "        [  0.2365],\n",
       "        [  0.1370],\n",
       "        [ -4.3021],\n",
       "        [ -1.1392],\n",
       "        [  0.2496],\n",
       "        [ -0.2212],\n",
       "        [ -7.7588],\n",
       "        [ -3.7485],\n",
       "        [ -1.7886],\n",
       "        [-12.5877],\n",
       "        [ -1.3500],\n",
       "        [ -0.7750],\n",
       "        [ -1.2404],\n",
       "        [ -7.1121],\n",
       "        [ -1.2353],\n",
       "        [ -9.7294],\n",
       "        [ -3.3617],\n",
       "        [ -3.0604],\n",
       "        [ -7.9138],\n",
       "        [ -2.6380],\n",
       "        [ -3.1741],\n",
       "        [ -0.3806],\n",
       "        [ -8.5584],\n",
       "        [ -5.1252],\n",
       "        [ -9.6294],\n",
       "        [ -9.0987]])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_target(reward, next_state, over):\n",
    "    #上面已经把模型认为的状态下执行动作的分数给评估出来了\n",
    "    #下面使用next_state和reward计算真实的分数\n",
    "    #针对一个状态,它到底应该多少分,可以使用以往模型积累的经验评估\n",
    "    #这也是没办法的办法,因为显然没有精确解,这里使用延迟更新的next_model评估\n",
    "\n",
    "    #使用next_state计算下一个状态的分数\n",
    "    #[b, 3] -> [b, 11]\n",
    "    with torch.no_grad():\n",
    "        target = next_model(next_state)\n",
    "    \"\"\"以下是DQN和Double DQN的主要区别\"\"\"\n",
    "    \"\"\"为了防止DNQTarget值过高估计的问题\"\"\"\n",
    "    #取所有动作中分数最大的\n",
    "    #[b, 11] -> [b, 1]\n",
    "    # target = target.max(dim=1)[0]\n",
    "    # target = target.reshape(-1, 1)\n",
    "    \"\"\"这里先用model计算下一个状态分数中，最高动作分对应的索引\"\"\"\n",
    "    with torch.no_grad():\n",
    "        model_target = model(next_state)\n",
    "    model_target = model_target.max(dim=1)[1]\n",
    "    model_target = model_target.reshape(-1,1)\n",
    "\n",
    "    target = target.gather(dim=1,index=model_target)\n",
    "\n",
    "    \"\"\"以上的DNQ和Double DQN的主要区别\"\"\"\n",
    "    #下一个状态的分数乘以一个系数,相当于权重\n",
    "    target = 0.98*target +(1-over)*reward\n",
    "\n",
    "\n",
    "    return target\n",
    "\n",
    "\n",
    "get_target(reward, next_state, over)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1238.8870605711936"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    #初始化游戏\n",
    "    state = env.reset()\n",
    "\n",
    "    #记录反馈值的和,这个值越大越好\n",
    "    reward_sum = 0\n",
    "\n",
    "    #玩到游戏结束为止\n",
    "    over = False\n",
    "    while not over:\n",
    "        #根据当前状态得到一个动作\n",
    "        _, action_continuous = get_action(state)\n",
    "\n",
    "        #执行动作,得到反馈\n",
    "        state, reward, over, _ = env.step([action_continuous])\n",
    "        reward_sum += reward\n",
    "\n",
    "        #打印动画\n",
    "        if play: \n",
    "            display.clear_output(wait=True)\n",
    "            show()\n",
    "\n",
    "    return reward_sum\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第 0 轮——————————分数：-507.57936676822663\n",
      "第 20 轮——————————分数：-229.09667556743437\n",
      "第 40 轮——————————分数：-160.12458985000126\n",
      "第 60 轮——————————分数：-459.54501520012843\n",
      "第 80 轮——————————分数：-225.75658475966875\n",
      "第 100 轮——————————分数：-250.7507012864684\n",
      "第 120 轮——————————分数：-416.3873553012484\n",
      "第 140 轮——————————分数：-235.19321269220376\n",
      "第 160 轮——————————分数：-195.57979628830918\n",
      "第 180 轮——————————分数：-284.82261890849867\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    #训练N次\n",
    "    for epoch in range(200):\n",
    "        #更新N条数据\n",
    "        update_data()\n",
    "\n",
    "        #每次更新过数据后,学习N次\n",
    "        for i in range(201):\n",
    "            #采样一批数据\n",
    "            state, action, reward, next_state, over = get_sample()\n",
    "\n",
    "            #计算一批样本的value和target\n",
    "            value = get_value(state, action)\n",
    "            target = get_target(reward, next_state, over)\n",
    "\n",
    "            #更新参数\n",
    "            loss = loss_fn(value, target)\n",
    "            optimizer.zero_grad()#清空梯度防止累积\n",
    "            loss.backward()#计算梯度\n",
    "            optimizer.step()#更新参数\n",
    "\n",
    "            #把model的参数复制给next_model\n",
    "            if (i + 1) % 50 == 0:\n",
    "                next_model.load_state_dict(model.state_dict())\n",
    "\n",
    "        if epoch % 20 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(20)]) / 20\n",
    "            print(f'第 {epoch} 轮——————————分数：{test_result}')\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAoN0lEQVR4nO3df3RU5YH/8c9MJjP5xUxCIIkpRNhK1ZQfq6Aw1a1Yo1FZqzWtrsejVF09YvSIdD0rXcWt7Vmo7rdWu4o9q6u2p5YedhetCNp8o4YqkR+RaACNbaElApMgkJkkkklm5vn+YZmvkV8J3Mk8E96vc+YcufeZO89cQ97MzJ17XcYYIwAALORO9wQAADgSIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsFbaIvXEE09owoQJysnJ0cyZM7V+/fp0TQUAYKm0ROo3v/mNFixYoAcffFDvvvuupk2bpurqanV0dKRjOgAAS7nScYLZmTNn6pxzztF//Md/SJISiYTGjx+vu+66S/fdd99wTwcAYCnPcD9gX1+fmpqatHDhwuQyt9utqqoqNTY2HvY+0WhU0Wg0+edEIqF9+/apuLhYLpcr5XMGADjLGKOuri6Vl5fL7T7ym3rDHqlPPvlE8XhcpaWlA5aXlpbqww8/POx9Fi9erB/84AfDMT0AwDBqa2vTuHHjjrh+2CN1PBYuXKgFCxYk/xwOh1VRUaG2tjb5/f40zgwAcDwikYjGjx+vUaNGHXXcsEdqzJgxysrKUnt7+4Dl7e3tKisrO+x9fD6ffD7fIcv9fj+RAoAMdqyPbIb96D6v16vp06ervr4+uSyRSKi+vl7BYHC4pwMAsFha3u5bsGCB5s6dqxkzZujcc8/VT3/6U/X09Oimm25Kx3QAAJZKS6SuvfZa7dmzR4sWLVIoFNLf/u3f6tVXXz3kYAoAwMktLd+TOlGRSESBQEDhcJjPpAAgAw329zjn7gMAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgrSFHas2aNbriiitUXl4ul8ulF198ccB6Y4wWLVqkU045Rbm5uaqqqtIf/vCHAWP27dun66+/Xn6/X4WFhbrlllvU3d19Qk8EADDyDDlSPT09mjZtmp544onDrn/44Yf1+OOP66mnntK6deuUn5+v6upq9fb2Jsdcf/312rJli+rq6rRy5UqtWbNGt9122/E/CwDAyGROgCSzYsWK5J8TiYQpKyszjzzySHJZZ2en8fl85te//rUxxpitW7caSWbDhg3JMatXrzYul8vs3LlzUI8bDoeNJBMOh09k+gCANBns73FHP5Pavn27QqGQqqqqkssCgYBmzpypxsZGSVJjY6MKCws1Y8aM5Jiqqiq53W6tW7fusNuNRqOKRCIDbgCAkc/RSIVCIUlSaWnpgOWlpaXJdaFQSCUlJQPWezwejR49OjnmixYvXqxAIJC8jR8/3slpAwAslRFH9y1cuFDhcDh5a2trS/eUAADDwNFIlZWVSZLa29sHLG9vb0+uKysrU0dHx4D1sVhM+/btS475Ip/PJ7/fP+AGABj5HI3UxIkTVVZWpvr6+uSySCSi
devWKRgMSpKCwaA6OzvV1NSUHPP6668rkUho5syZTk4HAJDhPEO9Q3d3t/74xz8m/7x9+3Y1Nzdr9OjRqqio0Pz58/WjH/1IkyZN0sSJE/XAAw+ovLxcV111lSTpzDPP1KWXXqpbb71VTz31lPr7+3XnnXfqH/7hH1ReXu7YEwMAjABDPWzwjTfeMJIOuc2dO9cY89lh6A888IApLS01Pp/PXHTRRaa1tXXANvbu3Wuuu+46U1BQYPx+v7nppptMV1eX44cuAgDsNNjf4y5jjEljI49LJBJRIBBQOBzm8ykAyECD/T2eEUf3AQBOTkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGCtIZ9gFsDxiba3q+ejj9S3Z48S0ajcOTnyjh2r/NNPl2/s2HRPD7ASkQJSyBijWCSivfX12tfQ8FmgDhyQicfl8njkzs2Vd+xYFV94oYovvFBZo0bJ5XKle9qANYgUkEK9H3+sHT//ubpbWqQvnMvZxGKKd3XpQFeXPt6+XZFNmzT+ttuUwyVrgCQ+kwJSpHfnTn389NPqfv/9QwJ1CGMUefddffzMM+rdvXt4JghkACIFpECsu1vtK1Yo8t57Q7pf+N131fHSS4r39KRoZkBmIVKAw4wxijQ365Pf/U5KJIZ253hce1atUtfWrcrAS70BjiNSgMNMf786Vq48oW10vPSSTDzu0IyAzEWkAIeZeFyftrae0Da6P/xw6K/CgBGISAGWSvT3p3sKQNoRKcBGxijW2ZnuWQBpR6QAh23/938/8c+TjFHf3r3OTAjIYEQKcFgsEjnhbZhEQt1btjgwGyCzESnARomEujdvTvcsgLQjUoDDsgoK0j0FYMQgUoDDRs+eLbn5qwU4gb9JgMOyi4oc2Y5JJJSIxRzZFpCpiBTgMO+YMZIDl9tI9PUp3t3twIyAzEWkAId5AgE5cUWoRF+fYkQKJzkiBTjM5dDnUYneXsW7uhzZFpCpiBRgqb72dr4rhZMekQIsxuU6cLIjUoDDXB6Piv7u7xzbHqHCyYxIAQ5zud3KPfVUR7ZlYrFjX3oeGMGIFOA0l0seh74rFevqkuGSHTiJESnAaS6Xsv1+RzYVj0S4rhROakQKcJjL5XLstEgHPv5Y8Z4eR7YFZCIiBVjswLZtivFdKZzEiBSQAq6sLLmys9M9DSDjESkgBXLKy1VQWZnuaQAZj0gBKeD2+ZSVn+/IthLRKN+VwkmLSAEp4Pb5HLv4Yf++fY5sB8hERApIAZfXq6zcXEe21bd3ryPbATIRkQJSwOVyOXJNKUna9+abnHUCJy0iBVguGgoRKZy0iBSQIr7SUrl9vnRPA8hoRApIkby/+RvHjvADTlZECkgRTyDgzBd6jVEiGj3x7QAZiEgBKeLx++V2KFL9+/ef+HaADESkgBTJysuTKyvrhLdjEgn1d3ae+ISADESkgBRxOXQmdCUS+nTbNme2BWQYIgVYzsRi6mpuTvc0gLQgUkAKubzedE8ByGhECkihsXPmpHsKQEYjUkAKeUePdmQ7Jh7nMvI4KREpIIWyi4sdOYdforeXy8jjpESkgBTyBAKObCdOpHCSIlJACrkcOhN6X3u7Pt2+3ZFtAZmESAEZIN7ToxhnncBJiEgBKeRyu+U75ZR0TwPIWEQKSCFXdrZGTZ3q2PYM15XCSYZIASnkcruV7dBh6IlolIsf4qRDpIBUcruVXVTkyKb6IxGZWMyRbQGZgkgBqeRyOXbhwziRwkloSJFavHixzjnnHI0aNUolJSW66qqr1NraOmBMb2+vamtrVVxcrIKCAtXU1Ki9vX3AmB07dmjOnDnKy8tTSUmJ7r33XsX4y4cRyOVyOXY29N5QiIsf4qQzpL89DQ0Nqq2t1TvvvKO6ujr19/frkksuUc/nvmR4zz336OWXX9by5cvV0NCgXbt26eqrr06uj8fjmjNnjvr6+rR27Vo9//zzeu6557Ro0SLnnhUw
AvVs3apYd3e6pwEMK5c5gcOF9uzZo5KSEjU0NOjrX/+6wuGwxo4dqxdeeEHf/va3JUkffvihzjzzTDU2NmrWrFlavXq1/v7v/167du1SaWmpJOmpp57SP//zP2vPnj3yDuKs0ZFIRIFAQOFwWH6//3inDwyL/Y2N2rZkiSMHPVQ+8YRyx493YFZAeg329/gJvQ8RDoclSaP/evRSU1OT+vv7VVVVlRxzxhlnqKKiQo2NjZKkxsZGTZkyJRkoSaqurlYkEtGWLVsO+zjRaFSRSGTADcgUeRMnqqCy0rHtcRg6TibHHalEIqH58+frvPPO0+TJkyVJoVBIXq9XhYWFA8aWlpYqFAolx3w+UAfXH1x3OIsXL1YgEEjexvMvSWQQd26uYwdPJHp7HdkOkCmOO1K1tbXavHmzli1b5uR8DmvhwoUKh8PJW1tbW8ofE3BKVk6OY5Hq37vXke0AmcJzPHe68847tXLlSq1Zs0bjxo1LLi8rK1NfX586OzsHvJpqb29XWVlZcsz69esHbO/g0X8Hx3yRz+eTz+c7nqkCaefyeuV26Oc3umePI9sBMsWQXkkZY3TnnXdqxYoVev311zVx4sQB66dPn67s7GzV19cnl7W2tmrHjh0KBoOSpGAwqJaWFnV0dCTH1NXVye/3q9LB9+0BWzh1JnRJCn/hH3jASDekV1K1tbV64YUX9NJLL2nUqFHJz5ACgYByc3MVCAR0yy23aMGCBRo9erT8fr/uuusuBYNBzZo1S5J0ySWXqLKyUjfccIMefvhhhUIh3X///aqtreXVEnAMB/7853RPARhWQ4rU0qVLJUmzZ88esPzZZ5/Vd7/7XUnSo48+KrfbrZqaGkWjUVVXV+vJJ59Mjs3KytLKlSs1b948BYNB5efna+7cuXrooYdO7JkAFvOOGSOXx8MZI4AhOqHvSaUL35NCpun64ANt+7d/U+yvX9s4Xp5AQFN/8QtH30IE0mFYvicFYHCyAwG5PMd1nNJAxnAYOk4qRAoYBp5AQK6srBPejkkk1M8VenESIVLAMMjKzXUmUvG4+js7T3xCQIYgUsAwcOpM6CYeV98RzswCjERECsggpq9P4aamdE8DGDZEChguHJEHDBmRAoZJ6Te/6di2MvCbI8BxIVLAMPEe4dyUQ2X6+2X6+x3ZFmA7B764AWAwsv963bXPO/iKKPHX/z74+sglKcvlOuyXduO9vYofOCD3IC4QCmQ6IgUME8+oUcn/NsaoNx5XR2+vtnd36719+/SnSER7env1aSymC8rKdHdlpbIOE6nEgQOK9/QoOxAYzukDaUGkgGHWl0ho0969WhMK6f39+xU3Rn9TUKAzCwv1dzk5yvd4VJGff8RTH0Xb2xXdvVs55eXDPHNg+BEpYJgkjNG2ri4t275dm/buVUV+vm788pf1lUBARV6v8jyeI77F93mxzk4ufoiTBpEChkFXV5dWvPyyHmpqUklOjr731a9qxpgxybfzOGEscHhECkixvXv36sknn9Qvf/EL1Zx3nmbHYhqbk0OYgEEgUkAKdXV1acmSJVq9erV+vHixpkUi2v/iiye83Xhvr0wi4djplgBb8RMOpEhPT49+8pOfqK6uTj/96U/1zauuUkFpqSPb7u/slInHHdkWYDNeSQEpEIvFtGLFCr3wwgt6+OGHNXv2bLmzspSVm+vM9sPhzyKVne3I9gBb8UoKcJgxRps3b9bPfvYz/eM//qMuv/xyeTweRz+D2v/WW4p3dzu2PcBWRApwWF9fn5588kmNGzdO3/3ud5Wdglc7iQMHZBIJx7cL2IZIAQ4yxuj999/X22+/rRtvvFHFxcUD1ntLSuQdOzZNswMyD5ECHBSLxfTUU09p6tSpuvjii+X+wtF32cXFyh4zxpkH40zoOAkQKcBBW7du1YYNG3TjjTcqLy/vkPVZeXmOHTwR//RTLtmBEY9IAQ4xxuidd96R1+vV+eeff9gxWTk5jkXqWKdGMsYozmHqyHBE
CnBIV1eXmpqa9PWvf/2wr6IkyeX1yuXQJTb6Pvnk6Ov7+vTMM88oFos58nhAOhApwCH79u3TRx99pPPPP/+Qz6IOGsxh6Dt7erSyrU2/3rZN/3fXLvUc4QKHXS0tR91Oc3OznnnmGW3btu3YkwcsxZd5AYdEIhFFIhGdeuqpx3V/Y4y2d3frwU2b9OfubvXG4/JnZ2tyUZH+/ZxzlP2F8PV8+OFRt7V8+XJt27ZN69at06RJkzhXIDISr6QAh/T09MjlcqmgoOCoQfD4/dJhXmlt6+7WrW+/rQ/CYR2Ix2Ukhfv79XZHh+5et057e3sHPZdt27bpjTfe0L59+7RmzRp1dXUdz1MC0o5IAQ6JxWLy+Xzy+XxHHVcYDB724Imfbtmi8BHe2lv/ySeq27VrUPMwxujtt9/W9u3bJUlr167V7t27ORIQGYlIAQ7yer3Kyso66pjswkK5jjHmREQiEf3+97+X1+uVx+PRzp079f7776fs8YBUIlKAg/r7+5U4xumKsgsLpRRFyhijvXv36ktf+pJ+/OMf68wzz9QPf/hD/fGPfzzmvAAbESnAIW63W9FoVH19fUcfl5ursZdeesjyOePHK/sIn2VNKCjQ1NGjBzWPMWPG6K677pLf71dubq5qamp0zTXXHPGIQ8BmHN0HOCQ3N1fxeFyf/vVMEEc6eMLlcqn0yisVDYW07403ksury8slST967z31xeNKSMpyuVTo9er/nHOOTi0oGLCd0pqaw27b7/dLkv7yl78oLy9PhYWFKv/rtoFMQ6QAhwQCAeXl5WnXrl2aOnXqUce6c3NVFAwqvHGj4n898s7lcqm6vFzj8vK08uOPtbe3VxMKCnTtxIkq/sLBGL5TTlFRMHjEEB68XMipp56qXIfOcAGkA5ECHFJUVKQJEyZo48aNqq6uPuph6C6XS/6zztLYyy5T+//8T/Iquy6XS5OLijS5qOiI9/UUFan8hhs+O5T9CPbv368//elPqqmp4ftRyGi8SQ04JBAIaNq0aXrzzTcHdSoit8+n0m99S6O/8Q25PIP792JWQYFOufZaFZ577hGPEDTGqKWlRR0dHbrwwguH9BwA2xApwCFut1vnn3++QqGQ3nvvvUHdJysvT+NuvlmlNTXylpQccZwrK0u5EyZo/K23auxll8l9lPP/xWIxvfXWW/rSl76k0047bcjPA7AJb/cBDnG5XJo2bZpOO+00vfDCCzrrrLPkOcYrJJfLJU9+vk759rflnzpV+9euVfeWLYqGQkpEo8oqKFDOuHEKzJihwIwZyq2oOObbd+3t7aqvr9d3vvOdYz4+YDt+ggEH5eXl6eabb9aDDz6ojRs3aubMmYP6TMjt86lg8mTlf+Urih84IBOLySQScmVlyZ2dLXdentyDCE4ikdBrr72maDSqCy+8kMPOkfH4CQYc5HK5dPHFF6uiokLPPffckM6Z53K55Pb5lF1YKO+YMfKVlMhbXCyP3z+oQElSW1ubli5dqquvvlpf/vKXOWgCGY9IAQ7LycnRvHnztHbtWtXX1w/LOfOMMerp6dFjjz2m/Px8XXPNNcrOzk754wKpRqQAh7ndbs2ePVvf+c539K//+q/avHlzSk9JZIxRf3+/fvnLX2rVqlV66KGHNG7cuJQ9HjCciBSQArm5ubrjjjt01llnaf78+WppaUlZqIwx+u1vf6ulS5fqjjvu0Ne+9jXe5sOIQaSAFCkuLtaiRYuUk5Ojf/qnf9JHH33k6Ft/xhjFYjG99NJL+uEPf6grrrhCc+fO5W0+jChECkihiRMn6rHHHtPYsWN13XXX6dVXX1V3d/cJx8oYo46ODj3yyCP6/ve/r2uuuUb3339/8rx9wEjhMhl4JbRIJKJAIKBwOMxfSmSE3bt369FHH9Urr7yiyy+/XDfddJPOPPPM43pb7tNPP1VjY6P+8z//U62trZo3b57mzp17zIstAjYZ7O9xIgUMk0gkoldeeUVPPPGE+vv7dcUVV+ja
a6/VqaeeKo/HkwzW58N18K9nIpFQf3+/NmzYoGeeeUaNjY06++yz9b3vfU9TpkwhUMg4RAqwjDFGxhhFIhH96le/0m9+8xuFQiF95Stf0ezZszVlyhSNHTtWPp9PLpcredmPXbt2qbGxUW+99Zb279+vr371q7rpppt04YUXKjc3l4MkkJGIFGAxY4za29vV2Niod955R5s3b9bOnTsVjUaTZ4k4eDRgQUGBJk2apClTpmjWrFmaOXOm8vLy0jl94IQRKSADGGPU19en/fv3KxKJqKenR/39/TLGyO12y+fzKT8/X0VFRfL7/ZyLDyPGYH+P8xMPpJHL5ZLP51NZWZnKysrSPR3AOhyCDgCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsNKVJLly7V1KlT5ff75ff7FQwGtXr16uT63t5e1dbWqri4WAUFBaqpqVF7e/uAbezYsUNz5sxRXl6eSkpKdO+99yoWiznzbAAAI8qQIjVu3DgtWbJETU1N2rhxo77xjW/oyiuv1JYtWyRJ99xzj15++WUtX75cDQ0N2rVrl66++urk/ePxuObMmaO+vj6tXbtWzz//vJ577jktWrTI2WcFABgZzAkqKioyTz/9tOns7DTZ2dlm+fLlyXUffPCBkWQaGxuNMcasWrXKuN1uEwqFkmOWLl1q/H6/iUajg37McDhsJJlwOHyi0wcApMFgf48f92dS8Xhcy5YtU09Pj4LBoJqamtTf36+qqqrkmDPOOEMVFRVqbGyUJDU2NmrKlCkqLS1NjqmurlYkEkm+GjucaDSqSCQy4AYAGPmGHKmWlhYVFBTI5/Pp9ttv14oVK1RZWalQKCSv16vCwsIB40tLSxUKhSRJoVBoQKAOrj+47kgWL16sQCCQvI0fP36o0wYAZKAhR+r0009Xc3Oz1q1bp3nz5mnu3LnaunVrKuaWtHDhQoXD4eStra0tpY8HALCDZ6h38Hq9Ou200yRJ06dP14YNG/TYY4/p2muvVV9fnzo7Owe8mmpvb1dZWZkkqaysTOvXrx+wvYNH/x0cczg+n08+n2+oUwUAZLgT/p5UIpFQNBrV9OnTlZ2drfr6+uS61tZW7dixQ8FgUJIUDAbV0tKijo6O5Ji6ujr5/X5VVlae6FQAACPMkF5JLVy4UJdddpkqKirU1dWlF154QW+++aZee+01BQIB3XLLLVqwYIFGjx4tv9+vu+66S8FgULNmzZIkXXLJJaqsrNQNN9yghx9+WKFQSPfff79qa2t5pQQAOMSQItXR0aEbb7xRu3fvViAQ0NSpU/Xaa6/p4osvliQ9+uijcrvdqqmpUTQaVXV1tZ588snk/bOysrRy5UrNmzdPwWBQ+fn5mjt3rh566CFnnxUAYERwGWNMuicxVJFIRIFAQOFwWH6/P93TAQAM0WB/j3PuPgCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWOqFILVmyRC6XS/Pnz08u6+3tVW1trYqLi1VQUKCamhq1t7cPuN+OHTs0Z84c5eXlqaSkRPfee69isdiJTAUAMAIdd6Q2bNign//855o6deqA5ffcc49efvll
LV++XA0NDdq1a5euvvrq5Pp4PK45c+aor69Pa9eu1fPPP6/nnntOixYtOv5nAQAYmcxx6OrqMpMmTTJ1dXXmggsuMHfffbcxxpjOzk6TnZ1tli9fnhz7wQcfGEmmsbHRGGPMqlWrjNvtNqFQKDlm6dKlxu/3m2g0OqjHD4fDRpIJh8PHM30AQJoN9vf4cb2Sqq2t1Zw5c1RVVTVgeVNTk/r7+wcsP+OMM1RRUaHGxkZJUmNjo6ZMmaLS0tLkmOrqakUiEW3ZsuWwjxeNRhWJRAbcAAAjn2eod1i2bJneffddbdiw4ZB1oVBIXq9XhYWFA5aXlpYqFAolx3w+UAfXH1x3OIsXL9YPfvCDoU4VAJDhhvRKqq2tTXfffbd+9atfKScnJ1VzOsTChQsVDoeTt7a2tmF7bABA+gwpUk1NTero6NDZZ58tj8cjj8ejhoYGPf744/J4PCotLVVfX586OzsH3K+9vV1lZWWSpLKyskOO9jv454Njvsjn88nv9w+4AQBGviFF6qKLLlJLS4uam5uTtxkzZuj6669P/nd2drbq6+uT92ltbdWOHTsUDAYlScFgUC0tLero6EiOqaurk9/vV2VlpUNPCwAwEgzpM6lRo0Zp8uTJA5bl5+eruLg4ufyWW27RggULNHr0aPn9ft11110KBoOaNWuWJOmSSy5RZWWlbrjhBj388MMKhUK6//77VVtbK5/P59DTAgCMBEM+cOJYHn30UbndbtXU1Cgajaq6ulpPPvlkcn1WVpZWrlypefPmKRgMKj8/X3PnztVDDz3k9FQAABnOZYwx6Z7EUEUiEQUCAYXDYT6fAoAMNNjf45y7DwBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLU+6J3A8jDGSpEgkkuaZAACOx8Hf3wd/nx9JRkZq7969kqTx48eneSYAgBPR1dWlQCBwxPUZGanRo0dLknbs2HHUJ3eyi0QiGj9+vNra2uT3+9M9HWuxnwaH/TQ47KfBMcaoq6tL5eXlRx2XkZFyuz/7KC0QCPBDMAh+v5/9NAjsp8FhPw0O++nYBvMigwMnAADWIlIAAGtlZKR8Pp8efPBB+Xy+dE/FauynwWE/DQ77aXDYT85ymWMd/wcAQJpk5CspAMDJgUgBAKxFpAAA1iJSAABrZWSknnjiCU2YMEE5OTmaOXOm1q9fn+4pDas1a9boiiuuUHl5uVwul1588cUB640xWrRokU455RTl5uaqqqpKf/jDHwaM2bdvn66//nr5/X4VFhbqlltuUXd39zA+i9RavHixzjnnHI0aNUolJSW66qqr1NraOmBMb2+vamtrVVxcrIKCAtXU1Ki9vX3AmB07dmjOnDnKy8tTSUmJ7r33XsViseF8Kim1dOlSTZ06NfnF02AwqNWrVyfXs48Ob8mSJXK5XJo/f35yGfsqRUyGWbZsmfF6vea//uu/zJYtW8ytt95qCgsLTXt7e7qnNmxWrVpl/uVf/sX87//+r5FkVqxYMWD9kiVLTCAQMC+++KJ57733zDe/+U0zceJEc+DAgeSYSy+91EybNs2888475ve//7057bTTzHXXXTfMzyR1qqurzbPPPms2b95smpubzeWXX24q
KipMd3d3csztt99uxo8fb+rr683GjRvNrFmzzNe+9rXk+lgsZiZPnmyqqqrMpk2bzKpVq8yYMWPMwoUL0/GUUuK3v/2teeWVV8xHH31kWltbzfe//32TnZ1tNm/ebIxhHx3O+vXrzYQJE8zUqVPN3XffnVzOvkqNjIvUueeea2pra5N/jsfjpry83CxevDiNs0qfL0YqkUiYsrIy88gjjySXdXZ2Gp/PZ379618bY4zZunWrkWQ2bNiQHLN69WrjcrnMzp07h23uw6mjo8NIMg0NDcaYz/ZJdna2Wb58eXLMBx98YCSZxsZGY8xn/xhwu90mFAolxyxdutT4/X4TjUaH9wkMo6KiIvP000+zjw6jq6vLTJo0ydTV1ZkLLrggGSn2Vepk1Nt9fX19ampqUlVVVXKZ2+1WVVWVGhsb0zgze2zfvl2hUGjAPgoEApo5c2ZyHzU2NqqwsFAzZsxIjqmqqpLb7da6deuGfc7DIRwOS/r/JyduampSf3//gP10xhlnqKKiYsB+mjJlikpLS5NjqqurFYlEtGXLlmGc/fCIx+NatmyZenp6FAwG2UeHUVtbqzlz5gzYJxI/T6mUUSeY/eSTTxSPxwf8T5ak0tJSffjhh2malV1CoZAkHXYfHVwXCoVUUlIyYL3H49Ho0aOTY0aSRCKh+fPn67zzztPkyZMlfbYPvF6vCgsLB4z94n463H48uG6kaGlpUTAYVG9vrwoKCrRixQpVVlaqubmZffQ5y5Yt07vvvqsNGzYcso6fp9TJqEgBx6O2tlabN2/WW2+9le6pWOn0009Xc3OzwuGw/vu//1tz585VQ0NDuqdllba2Nt19992qq6tTTk5OuqdzUsmot/vGjBmjrKysQ46YaW9vV1lZWZpmZZeD++Fo+6isrEwdHR0D1sdiMe3bt2/E7cc777xTK1eu1BtvvKFx48Yll5eVlamvr0+dnZ0Dxn9xPx1uPx5cN1J4vV6ddtppmj59uhYvXqxp06bpscceYx99TlNTkzo6OnT22WfL4/HI4/GooaFBjz/+uDwej0pLS9lXKZJRkfJ6vZo+fbrq6+uTyxKJhOrr6xUMBtM4M3tMnDhRZWVlA/ZRJBLRunXrkvsoGAyqs7NTTU1NyTGvv/66EomEZs6cOexzTgVjjO68806tWLFCr7/+uiZOnDhg/fTp05WdnT1gP7W2tmrHjh0D9lNLS8uAoNfV1cnv96uysnJ4nkgaJBIJRaNR9tHnXHTRRWppaVFzc3PyNmPGDF1//fXJ/2ZfpUi6j9wYqmXLlhmfz2eee+45s3XrVnPbbbeZwsLCAUfMjHRdXV1m06ZNZtOmTUaS+clPfmI2bdpk/vKXvxhjPjsEvbCw0Lz00kvm/fffN1deeeVhD0E/66yzzLp168xbb71lJk2aNKIOQZ83b54JBALmzTffNLt3707ePv300+SY22+/3VRUVJjXX3/dbNy40QSDQRMMBpPrDx4yfMkll5jm5mbz6quvmrFjx46oQ4bvu+8+09DQYLZv327ef/99c9999xmXy2V+97vfGWPYR0fz+aP7jGFfpUrGRcoYY372s5+ZiooK4/V6zbnnnmveeeeddE9pWL3xxhtG0iG3uXPnGmM+Owz9gQceMKWlpcbn85mLLrrItLa2DtjG3r17zXXXXWcKCgqM3+83N910k+nq6krDs0mNw+0fSebZZ59Njjlw4IC54447TFFRkcnLyzPf+ta3zO7duwds589//rO57LLLTG5urhkzZoz53ve+Z/r7+4f52aTOzTffbE499VTj9XrN2LFjzUUXXZQMlDHso6P5YqTYV6nBpToAANbKqM+kAAAnFyIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCs9f8AjoYg9lEz8cAAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-130.84996326830878"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#Animate one episode with the current policy (test redraws via show() each step);\n",
    "#the value shown below the cell is that episode's total reward.\n",
    "test(play=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 3],\n",
       "        [5, 6],\n",
       "        [2, 6]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## torch.gather 函数\n",
    "\n",
    "a=torch.tensor([\n",
    "    [1,2,3],\n",
    "    [2,5,6],\n",
    "    [2,6,5],\n",
    "    [7,5,6]\n",
    "])\n",
    "\n",
    "#我现在想选，第一行的第一个和第三个元素\n",
    "#第二行的第二个和第三个元素\n",
    "#第三行的第一个和第二个元素\n",
    "#可以构造如下索引矩阵\n",
    "indexTensor=torch.tensor([\n",
    "    [0,2],\n",
    "    [1,2],\n",
    "    [0,1]\n",
    "])\n",
    "torch.gather(a,dim=1,index=indexTensor)\n",
    "#dim=1  ==> 竖着取元素，\n",
    "#dim=0 ==> 按列选元素,每一行的数字个数不能超过被索引的列数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 6, 6],\n",
       "        [2, 5, 3]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "##我想取第一列的第1、3、4个元素，第二列的第3，2，1个元素\n",
    "indexten = torch.tensor([\n",
    "    [0,2,3],\n",
    "    [2,1,0]\n",
    "])\n",
    "b=torch.gather(a,dim=0,index=indexten)\n",
    "b"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Gym",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
